In [2]:
# Standard plotting setup
%matplotlib inline
import matplotlib.pyplot as plt
from pylab import rcParams
rcParams['figure.figsize'] = 10, 5
plt.style.use('ggplot')
Let's model the Bernouilli multi-armed bandit. The Bernoulli MBA is an $N$-armed bandit where each arm gives binary rewards according to some probability:
$r_i \sim Bernouilli(\mu_i)$
Here $i$ is the index of the arm. Let's model this as a Markov decision process. The state is going to be defined as:
$s(t) = (\alpha_1, \beta_1, \ldots, \alpha_N, \beta_N, r_t)$
$\alpha_i$ is the number of successes encountered so far when pulling arm $i$. $\beta_i$ is, similarly, the number of failures encountered when pulling that arm. $r_t$ is the reward, either 0 or 1, from the last trial.
Assuming a uniform prior on $\mu_i$, the posterior distribution of the $\mu_i$ in a given state are:
$p(\mu_i|s(t)) = Beta(\alpha_i+1,\beta_i+1)$
When we're in a given state, we have the choice of performing one of $N$ actions, corresponding to pulling each of the arms. Let's call pulling the $i$'th arm $a_i$. This will put us in a new state, with a certain probability. The new state will be same for arms not equal to i. For the $i$'th arm, we have:
$s(t+1) = (\ldots \alpha_i + 1, \beta_i \ldots 1)$ with probability $(\alpha_i+1)/(\alpha_i+\beta_i+2)$
$s(t+1) = (\ldots \alpha_i, \beta_i + 1 \ldots 0)$ with probability $(\beta_i+1)/(\alpha_i+\beta_i+2)$
We can solve exactly for this MDP, e.g. using value iteration, given that it's small enough. For $M$ trials, the state space is $M^{2N}$ - it's possible to solve the 2-armed bandit for 10-20 trials this way, but it grows exponentially fast.
In [3]:
import itertools
import numpy as np
from pprint import pprint
def sorted_values(dict_):
return [dict_[x] for x in sorted(dict_)]
def solve_bmab_value_iteration(N_arms, M_trials, gamma=1,
max_iter=10, conv_crit = .01):
util = {}
# Initialize every state to utility 0.
state_ranges = [range(M_trials+1) for x in range(N_arms*2)]
# The reward state
state_ranges.append(range(2))
for state in itertools.product(*state_ranges):
# Some states are impossible to reach.
if sum(state[:-1]) > M_trials:
# A state with the total of alphas and betas greater than
# the number of trials.
continue
if sum(state[:-1:2]) == 0 and state[-1] == 1:
# A state with a reward but alphas all equal to 0.
continue
if sum(state[:-1:2]) == M_trials and state[-1] == 0:
# A state with no reward but alphas adding up to M_trials.
continue
if sum(state[:-1]) == 1 and sum(state[:-1:2]) == 1 and state[-1] == 0:
# A state with an initial reward according to alphas but not according
# the reward index
continue
util[state] = 0
# Main loop.
converged = False
new_util = util.copy()
opt_actions = {}
for j in range(max_iter):
# Line 5 of value iteration
for state in util.keys():
reward = state[-1]
# Terminal state.
if sum(state[:-1]) == M_trials:
new_util[state] = reward
continue
values = np.zeros(N_arms)
# Consider every action
for i in range(N_arms):
# Successes and failure for this state.
alpha = state[i*2]
beta = state[i*2+1]
# Two possible outcomes: either that arm gets rewarded,
# or not.
# Transition to unrewarded state:
state0 = list(state)
state0[-1] = 0
state0[2*i+1] += 1
state0 = tuple(state0)
# The probability that we'll transition to this unrewarded state.
p_state0 = (beta + 1) / float(alpha + beta + 2)
# Rewarded state.
state1 = list(state)
state1[-1] = 1
state1[2*i] += 1
state1 = tuple(state1)
p_state1 = 1 - p_state0
try:
value = gamma*(util[state0]*p_state0 +
util[state1]*p_state1)
except KeyError,e:
print state
print state0
print state1
raise e
#print state0, util[state0], p_state0
#print state1, util[state1], p_state1
values[i] = value
#print state, values, reward
new_util[state] = reward + np.max(values)
opt_actions[state] = np.argmax(values)
# Consider the difference between the new util
# and the old util.
max_diff = np.max(abs(np.array(sorted_values(util)) - np.array(sorted_values(new_util))))
util = new_util.copy()
print "Iteration %d, max diff = %.5f" % (j, max_diff)
if max_diff < conv_crit:
converged = True
break
#pprint(util)
if converged:
print "Converged after %d iterations" % j
else:
print "Not converged after %d iterations" % max_iter
return util, opt_actions
util, opt_actions = solve_bmab_value_iteration(2, 2, max_iter=5)
In [4]:
opt_actions
Out[4]:
For the 2-armed, 2-trial Bernoulli bandit, the strategy is simple: pick the first arm. If it rewards, then pick it again. If not, pick the other. Note that this is the same as most sensible strategies, for instance greedy or UCB.
In [5]:
util
Out[5]:
Note that the utility of the root node is 1.08 - what does that mean? If we get rewarded in the initial trial, that means that the posterior for the mean of that arm is .67. OTOH, when we fail on the first trial, we can still pick the other arm, which still has a posterior mean of .5. Thus, we have rewards:
That means the expected total reward is:
In [28]:
2*.5*2.0/3.0 + .5/3.0 + .5*.5
Out[28]:
And that's what utility means in this context. Let's see about the 3-trial 2-armed bandit:
In [6]:
util, opt_actions = solve_bmab_value_iteration(2, 3, max_iter=5)
opt_actions
Out[6]:
The optimal strategy goes: pick arm 0. If it rewards, pick it again for the next 2 trials. If it doesn't reward, then pick arm 1. If that rewards, keep that one. If it doesn't, pick 0 again.
Let's see with 4:
In [7]:
util, opt_actions = solve_bmab_value_iteration(2, 4, max_iter=6)
What's interesting here is that value iteration always converges in M_trials + 1 iterations - information only travels backwards through time - much as in Viterbi in the context of HMMs. If we're only interested in the next best action given the current state, it might be possible to iterate backwards through time, starting from the terminal states, throwing away the latest data as we go along. Before we get into premature optimization, however, let's see how far we can look ahead without crashing with Chromebook.
In [8]:
M_trials = 16
%time util, opt_actions = solve_bmab_value_iteration(2, M_trials, max_iter=M_trials+2)
In [9]:
M_trials
bad_keys = [k for k in opt_actions.keys() if sum(k[:-1]) > 15]
assert(len(bad_keys) == 0)
It seems like my Chromebook can look ahead at least sixteen steps into the future without dying - pretty good!
In [10]:
# Create a design matrix related to the optimal strategies.
X = []
y = []
seen_keys = {}
for key, val in opt_actions.iteritems():
if key[:-1] in seen_keys:
# We've already seen this, continue.
continue
alpha0 = float(key[0] + 1)
beta0 = float(key[1] + 1)
alpha1 = float(key[2] + 1)
beta1 = float(key[3] + 1)
if alpha0 == alpha1 and beta0 == beta1:
# We're in a perfectly symmetric situtation, skip this then.
continue
seen_keys = key[:-1]
# Standard results for the Beta distribution.
# https://en.wikipedia.org/wiki/Beta_distribution
mean0 = alpha0/(alpha0 + beta0)
mean1 = alpha1/(alpha1 + beta1)
std0 = np.sqrt(alpha0*beta0 / (alpha0 + beta0 + 1)) / (alpha0 + beta0)
std1 = np.sqrt(alpha1*beta1 / (alpha1 + beta1 + 1)) / (alpha1 + beta1)
t = alpha0 + beta0 + alpha1 + beta1
X.append([mean0,mean1,std0,std1,t,1,alpha0 - 1,beta0 - 1,alpha1 - 1,beta1 - 1])
y.append(val)
X = np.array(X)
y = np.array(y)
Let's train a supervised network a see how well it can predict the correct move based on a purely greedy heuristic - and based on a heuristic which takes into account the uncertainty in the estimate.
In [11]:
from sklearn.linear_model import LogisticRegression
the_model = LogisticRegression(C=100.0)
X_ = X[:,:2]
the_model.fit(X_,y)
y_pred = the_model.predict(X_)
print ("Greedy: %.4f%% of moves are incorrect" % ((np.mean(abs(y_pred-y)))*100))
print the_model.coef_
the_model = LogisticRegression(C=100.0)
X_ = X[:,:4]
the_model.fit(X_,y)
y_pred = the_model.predict(X_)
print ("UCB: %.4f%% of moves are incorrect" % ((np.mean(abs(y_pred-y)))*100))
print the_model.coef_
the_model = LogisticRegression(C=100000.0)
X_ = X[:,:4]
X_ = np.hstack((X_,(X[:,4]).reshape((-1,1))*X[:,2:4]))
the_model.fit(X_,y)
y_pred = the_model.predict(X_)
print ("UCB X time: %.4f%% of moves are incorrect" % ((np.mean(abs(y_pred-y)))*100))
print the_model.coef_
We see that the greedy strategy misses the right move 3% of the time, while UCB shaves that down to 1.8%. Pretty significant. The UCB parameter - a parameter which determines how much "bonus" should be given to uncertainty - is suspiciously low at (29.49 / 57.7 ~= .5). In the literature, people use something around 2-3.
Adding a parameter which is the cross of time and the standard deviation of the estimate reveals the source of this discrepancy: at the initial time point, the UCB parameter is high (496.7 / 201 ~ 2.5) and it ramps down linearly as a function of time to (496 - 26.26*16) / 200 ~= 0.4. Thus, the optimal strategy is similar to a UCB strategy, with a twist: the exploration bonus should ramp down as a function of time. This makes sense: new information is more valuable in the initial trials.
This UCB X time strategy misses only .5% of moves, which is quite good, all things considered.
The dynamic programming approach is of theoretical interest, but it doesn't scale well to other kinds of problems, like contextual bandits, problems with continuous-valued rewards, or problems with larger state spaces. Rather than exhaustively determining the outcome of every path, we can sample outcomes at random. Let's start by implement vanilla MCTS, where actions at every junction are sampled uniformly. We'll later upgrade to UCT.
In [ ]:
import collections
class VanillaMCTS:
def __init__(self, N_arms, M_trials, stochastic_rewards = True):
self.N_arms = N_arms
self.M_trials = M_trials
self.stochastic_rewards = stochastic_rewards
self.state_action_reward = collections.defaultdict(int)
self.state_action_visit = collections.defaultdict(int)
def find_best_action(self,
current_state,
max_horizon = 10,
max_samples = 100):
max_depth = min(self.M_trials - sum(current_state[::-1]), max_horizon)
for n in range(max_samples):
self.mcts_search(current_state, max_depth)
state_rewards = [(self.state_action_reward[state_action] /
float(self.state_action_visit[state_action]),state_action) for
state_action in self.state_action_reward.keys() if
state_action[0] == current_state]
max_reward, best_state_action = max(state_rewards)
_, best_action = best_state_action
return max_reward, best_action, state_rewards
def mcts_search(self, state, max_depth):
if max_depth == 0:
return 0
# Select an action
action = self.select_action(state)
# Pull that arm
next_state, expected_reward = self.perform_action(state, action)
if self.stochastic_rewards:
r = next_state[-1]
else:
r = expected_reward
reward = self.mcts_search(next_state, max_depth - 1) + r
# Memo-ize
state_action = (state,action)
self.state_action_reward[state_action] += reward
self.state_action_visit[state_action] += 1
return reward
def perform_action(self, state, action):
# Pull the arm in question; only valid for Bernouilli arm
alpha = state[action*2]
beta = state[action*2 + 1]
expected_reward = (alpha + 1)/float(alpha + beta + 2)
rewarded = np.random.rand() < expected_reward
state = list(state)
if rewarded:
state[action*2] += 1
state[-1] = 1
else:
state[action*2+1] += 1
state[-1] = 0
return (tuple(state), expected_reward)
def select_action(self, state):
# Select uniformily at random
return int(np.random.rand() * (len(state) / 2))
In [32]:
# Action 1 is better both from an exploration and an exploitation
# perspective, but not by that much
ambiguous_state = (2,2,2,1,0)
ndraws = 10
nsims_per = 1000
best_actions = np.zeros((nsims_per,ndraws,2))
max_rewards = np.zeros((nsims_per,2))
for k,stochastic_rewards in enumerate([False, True]):
for j in range(ndraws):
mcts = VanillaMCTS(2, 16, stochastic_rewards = stochastic_rewards)
for i in range(nsims_per):
_, best_action, state_reward = mcts.find_best_action(ambiguous_state, max_samples = 1)
best_actions[i,j,k] = best_action
plt.subplot(121)
plt.plot(best_actions[:,:,0].mean(1))
plt.xlabel('# draws')
plt.ylabel('p(picking) right arm after N draws')
plt.ylim([0,1])
plt.title('deterministic rewards')
plt.subplot(122)
plt.plot(best_actions[:,:,1].mean(1))
plt.xlabel('#draws')
plt.title('stochastic rewards')
plt.ylim([0,1])
Out[32]:
Generally, the probability of picking the right arm increases with the number of draws -- but it still takes several hundred trials to pick the right arm consistently, despite the fact that we're simulating a rather large difference. This gets better if we return deterministic rewards rather than stochastic. Let's see what happens with a harder case:
In [33]:
# Action 1 is better only from an exploitation perspective
ambiguous_state = (2,2,1,1,0)
ndraws = 10
nsims_per = 10000
best_actions = np.zeros((nsims_per,ndraws,2))
max_rewards = np.zeros((nsims_per,2))
for k,stochastic_rewards in enumerate([False, True]):
for j in range(ndraws):
mcts = VanillaMCTS(2, 16, stochastic_rewards = stochastic_rewards)
for i in range(nsims_per):
_, best_action, state_reward = mcts.find_best_action(ambiguous_state, max_samples = 1)
best_actions[i,j,k] = best_action
plt.subplot(121)
plt.plot(best_actions[:,:,0].mean(1))
plt.xlabel('# draws')
plt.ylabel('p(picking) right arm after N draws')
plt.ylim([0,1])
plt.title('deterministic rewards')
plt.subplot(122)
plt.plot(best_actions[:,:,1].mean(1))
plt.xlabel('#draws')
plt.title('stochastic rewards')
plt.ylim([0,1])
Out[33]:
In this case, the method does not converge to the correct action -- subsequent actions to the first are selected uniformily -- the method answers the question, which action should I pick next, if I picking actions at random later? The answer, in this case, is that it doesn't matter
In [184]:
%matplotlib inline
import matplotlib.pyplot as plt
plt.subplot(121)
plt.plot(max_rewards)
plt.subplot(122)
plt.plot(best_actions)
best_actions
Out[184]:
instead focus on the most promising paths. Let's do this via MCTS - the initial value of a branch is going to be set as the expected outcome plus an exploration bonus, and after the horizon of the tree search is reached, the path will be scored by its mean value. We'll compare this to a UCB strategy.
In [119]:
max_reward
best_action
Out[119]:
In [ ]:
# Find a case where the greedy strategy is incorrect
the_model = LogisticRegression(C=100.0)
X_ = X[:,:2]
the_model.fit(X_,y)
y_pred = the_model.predict(X_)
print X[np.where(y_pred != y)[0][0],:]
A few things to try: